from __future__ import print_function
import torch
import argparse


from Utils import load_data
from Utils import load_state
from Utils import load_dirpolnet
from Utils import load_optuna_setting
from Utils import norm_coord_to_abs
from Utils import NME_calc
from Utils import NME_calc_landmarkwise

from Environment import Env

from Models import Agent
from Models import Agent_relative
from Models import dirPolNet


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    argparse without a ``type`` keeps CLI values as strings, so
    ``--random_flip False`` would otherwise produce the truthy string
    ``"False"``. Booleans (the defaults) are passed through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Landmark Detection with Active Inference')

parser.add_argument('--task', type=str, default='COFW', help='which task to run (CelebA_aligned)')
parser.add_argument('--n_landmarks', type=int, default=29, help='the number of landmarks')
parser.add_argument('--log_interval', type=int, default=5, help='interval for log [batches]')

# Dataset preprocessing
parser.add_argument('--random_scale', type=_str2bool, default=False, help='Whether to apply random scale')
parser.add_argument('--random_flip', type=_str2bool, default=False, help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False, help='Whether to apply random rotation')

# Test setting
parser.add_argument('--batch_size', type=int, default=50, help='batch size for test')
parser.add_argument('--maximum_stage', type=int, default=2, help='Maximum detection stage for each landmark')
# A list of ints; test() indexes it with 0 (odd stages) and 1 (even stages),
# so nargs='+' is required for CLI overrides to keep the list shape.
parser.add_argument('--max_timestep', type=int, nargs='+', default=[30, 30, 30, 30],
                    help='maximum number of time-steps')
parser.add_argument('--center_coord_init', type=_str2bool, default=True, help='set initial coordinate to center')

# Parsed at import time; main() and test() read hyper-parameters from this namespace.
args = parser.parse_args()


def main():
    """Build the agents and direction-policy networks, run ``test`` and print the results.

    State for the agents comes from ``load_state``, and the networks are passed
    through ``load_dirpolnet``. The hyper-parameters returned by
    ``load_optuna_setting`` are stored as attributes of the global ``args``
    namespace, which is where ``test`` reads them from.
    """
    torch.cuda.empty_cache()
    _ignored, test_loader = load_data(args.task, args.batch_size, args.random_scale,
                                      args.random_flip, args.random_rotation)

    # Construction order: left-eye agent, right-eye agent, relative agent, then the three nets.
    agent_leye, agent_reye = (Agent(args.batch_size, args.maximum_stage).to(device)
                              for _ in range(2))
    agent_others = Agent_relative(args.batch_size, args.maximum_stage).to(device)
    dirpolnet_leye, dirpolnet_reye, dirpolnet_others = (dirPolNet().to(device)
                                                        for _ in range(3))

    leye_state, reye_state, others_state = load_state()
    dirpolnet_leye, dirpolnet_reye, dirpolnet_others = load_dirpolnet(
        dirpolnet_leye, dirpolnet_reye, dirpolnet_others)

    optuna_setting = load_optuna_setting()
    (args.lambda_control_start, args.lambda_decrease, args.lambda_f_init, args.lambda_freq,
     args.thr_control_start, args.thr_increase, args.thr_init, args.thr_freq,
     args.lambda_ft_1stage, args.lambda_ft_2stage) = optuna_setting

    results = test(agent_leye, leye_state, dirpolnet_leye,
                   agent_reye, reye_state, dirpolnet_reye,
                   agent_others, others_state, dirpolnet_others, test_loader)
    _nme_pupil, nme_ocular, nme_ocular_per_landmark, timesteps = results

    # Average the accumulated detection time over the batch axis.
    per_landmark_timesteps = timesteps.mean(1)

    ocular_line = "NME_ocular_mean: {:.3f}.. ".format(nme_ocular)
    print(ocular_line.ljust(15))

    landmark_line = "NME_ocular_l_mean: \n{}.. \n".format(nme_ocular_per_landmark)
    print(landmark_line.ljust(15))

    timestep_line = "detection_timesteps_total_mean: \n{}.. \n".format(per_landmark_timesteps)
    print(timestep_line.ljust(15))



# Landmark groups handled by each agent: (anchor indices used in stages 1-2,
# indices used in every later stage). Kept as lists because they are used as
# fancy indices into the hyper-parameter arrays and result tensors.
_LEYE_LANDMARKS = ([16], [0, 2, 4, 5, 8, 10, 12, 13])
_REYE_LANDMARKS = ([17], [1, 3, 6, 7, 9, 11, 14, 15])
_OTHER_LANDMARKS = ([21], [18, 19, 20, 22, 23, 24, 25, 26, 27, 28])


def _stage_hyperparams(l_idxs, idx):
    """Slice the tuned hyper-parameters for landmarks ``l_idxs`` at column ``idx``.

    ``idx`` is 0 for odd stages and 1 for even stages; ``lambda_ft`` is taken
    from ``lambda_ft_1stage`` or ``lambda_ft_2stage`` accordingly. The tuple is
    in the positional order expected by ``Agent.set_hyperparams``.
    """
    lambda_ft = args.lambda_ft_1stage[l_idxs] if idx == 0 else args.lambda_ft_2stage[l_idxs]
    return (args.thr_init[l_idxs][:, idx],
            args.thr_control_start[l_idxs][:, idx],
            args.thr_freq[l_idxs][:, idx],
            args.thr_increase[l_idxs][:, idx],
            args.lambda_control_start[l_idxs][:, idx],
            args.lambda_freq[l_idxs][:, idx],
            lambda_ft)


def _detect_landmark_group(env, agent, state, dirpolnet, landmark_group,
                           inferred_landmark_coords, detection_timesteps_total,
                           n_batches, relative=False):
    """Run all ``2 * args.maximum_stage`` stages of one agent on ``env``.

    Writes the agent's final coordinates into ``inferred_landmark_coords`` and
    adds ``agent.detection_time / n_batches`` into ``detection_timesteps_total``
    (both in place) for the landmark indices of each stage.

    Args:
        landmark_group: ``(anchor_idxs, rest_idxs)``; stages 1-2 use the anchor,
            later stages use the rest.
        relative: if True, ``agent.progress_detection`` additionally receives
            ``base_c``, the coordinates inferred for the anchor at stage 2.

    Raises:
        RuntimeError: if the agent does not finish within ``args.max_timestep``;
            the original code silently reused stale coordinates in that case.
    """
    anchor_idxs, rest_idxs = landmark_group
    center_init = args.center_coord_init
    coord = None
    base_c = None

    for stage in range(1, 2 * args.maximum_stage + 1):
        l_idxs = anchor_idxs if stage <= 2 else rest_idxs

        # Later stages start from the previous stage's (absolute) estimate.
        env.coord_init(center_init, coord, len(l_idxs))
        o_t = env.c_to_o()

        idx = 0 if stage % 2 == 1 else 1
        agent.set_hyperparams(*_stage_hyperparams(l_idxs, idx))
        agent.set_prior(state, stage)

        n_steps = args.max_timestep[idx]
        inferred = None
        for t in range(n_steps):
            end_step = t == n_steps - 1
            if relative:
                act_to_env = agent.progress_detection(o_t, dirpolnet, base_c, end_step)
            else:
                act_to_env = agent.progress_detection(o_t, dirpolnet, end_step)

            # A None action signals that the agent has finished this stage.
            if act_to_env is None:
                detection_timesteps_total[l_idxs] += agent.detection_time / n_batches
                inferred = agent.landmark_coords
                inferred_landmark_coords[l_idxs] = inferred
                break

            env.apply_action(act_to_env)
            o_t = env.current_o

        if inferred is None:
            raise RuntimeError(
                'agent did not finish stage {} for landmarks {} within {} time-steps'
                .format(stage, l_idxs, n_steps))

        coord = norm_coord_to_abs(inferred, img_size=[256, 256]).long()
        center_init = False
        if relative and stage == 2:
            base_c = inferred


def test(agent_leye, leye_state, dirpolnet_leye,
         agent_reye, reye_state, dirpolnet_reye,
         agent_others, others_state, dirpolnet_others, test_loader):
    """Run the three agents over ``test_loader`` and collect NME statistics.

    For every batch a single ``Env`` is built and the landmarks are detected
    group by group: ``agent_leye`` (anchor 16), ``agent_reye`` (anchor 17) and
    the relative ``agent_others`` (anchor 21) for the remaining landmarks.
    Hyper-parameters are read from the global ``args`` (set in ``main``).

    Returns:
        NME_pupil_mean, NME_ocular_mean: 0-dim tensors, mean NME over all
            samples multiplied by 100.
        NME_ocular_l_mean: tensor of shape ``(n_landmarks,)``, per-landmark mean
            ocular NME multiplied by 100.
        detection_timesteps_total: ``(n_landmarks, batch_size)`` tensor holding
            each batch's ``detection_time`` divided by the number of batches, summed.
    """
    for module in (agent_leye, dirpolnet_leye, agent_reye, dirpolnet_reye,
                   agent_others, dirpolnet_others):
        module.eval()

    n_landmarks = args.n_landmarks
    NME_pupil_samples = []
    NME_ocular_samples = []
    NME_pupil_l_samples = [[] for _ in range(n_landmarks)]
    NME_ocular_l_samples = [[] for _ in range(n_landmarks)]

    detection_timesteps_total = torch.zeros(n_landmarks, args.batch_size).to(device)
    n_batches = len(test_loader)

    groups = (
        (agent_leye, leye_state, dirpolnet_leye, _LEYE_LANDMARKS, False),
        (agent_reye, reye_state, dirpolnet_reye, _REYE_LANDMARKS, False),
        (agent_others, others_state, dirpolnet_others, _OTHER_LANDMARKS, True),
    )

    with torch.no_grad():
        for images, _tpts, pts, center, scale in test_loader:
            # NOTE(review): sized by args.batch_size, so a smaller final batch
            # presumably fails here — confirm the loader drops incomplete batches.
            inferred_landmark_coords = torch.zeros(n_landmarks, args.batch_size, 2).to(device)

            images = images.to(device)
            env = Env(images)

            for agent, state, dirpolnet, landmark_group, relative in groups:
                _detect_landmark_group(env, agent, state, dirpolnet, landmark_group,
                                       inferred_landmark_coords, detection_timesteps_total,
                                       n_batches, relative=relative)

            # (n_landmarks, batch, 2) -> (batch, n_landmarks, 2).
            coords = inferred_landmark_coords.permute(1, 0, 2)
            coords_abs = norm_coord_to_abs(coords, [256, 256]).view(-1, n_landmarks, 2)
            # Swap the two components of every coordinate pair before comparing with pts.
            coords_abs = coords_abs.flip(dims=[-1])

            NME_pupil, NME_ocular = NME_calc(coords_abs, pts, center, scale)
            NME_pupil_samples.extend(NME_pupil.tolist())
            NME_ocular_samples.extend(NME_ocular.tolist())

            NME_pupil_l, NME_ocular_l = NME_calc_landmarkwise(coords_abs, pts, center, scale)
            for l in range(n_landmarks):
                NME_pupil_l_samples[l].extend(NME_pupil_l[:, l].tolist())
                NME_ocular_l_samples[l].extend(NME_ocular_l[:, l].tolist())

    NME_pupil_mean = 100 * torch.FloatTensor(NME_pupil_samples).mean()
    NME_ocular_mean = 100 * torch.FloatTensor(NME_ocular_samples).mean()
    NME_ocular_l_mean = 100 * torch.FloatTensor(NME_ocular_l_samples).mean(1)

    return NME_pupil_mean, NME_ocular_mean, NME_ocular_l_mean, detection_timesteps_total



if __name__ == "__main__":
    # Run the evaluation only when executed as a script.
    main()
    